this builds up the ideas of the tabled asymmetric numeral system

In [41]:
from random import choices, random
from math import log2

# generate a random probability distribution
def gen_probs(symbols: list | str):
    return { s: random() for s in symbols }
PROBS = gen_probs("abc")
total = sum(PROBS.values())
for c in PROBS:
    assert isinstance(c, str) and len(c) == 1, f"invalid symbol: '{c}'; all symbols must be single characters"
    PROBS[c] /= total
SYMBOLS, WEIGHTS = list(PROBS.keys()), list(PROBS.values())

# there are many ways you can generate labelings
def spread(freqs: dict, length: int):
    labeling = ""
    total = sum(freqs.values())
    probs = { c: total / freq for c, freq in freqs.items() }
    charset = set(probs.keys())
    while len(labeling) < length - len(charset):
        min_c = min(probs, key=probs.get)
        labeling += min_c
        if min_c in charset: charset.remove(min_c)
        probs[min_c] += total / freqs[min_c]
    for c in charset: labeling += c
    return labeling
BLOCK_LABELING = spread(PROBS, 32)

print("given probs: ", PROBS)
print("generated labeling: ", BLOCK_LABELING)
print("labeling probs: ", { c: BLOCK_LABELING.count(c) / len(BLOCK_LABELING) for c in set(BLOCK_LABELING) })

given probs:  {'a': 0.351811355564492, 'b': 0.004820677449920793, 'c': 0.6433679669855872}
generated labeling:  caccaccaccaccacaccaccaccaccaccab
labeling probs:  {'a': 0.34375, 'b': 0.03125, 'c': 0.625}


In [3]:
# just for comparison
class Huffman:
    def __init__(self, freq_table):
        table = { c: "" for c in freq_table }
        sorted_chars = sorted(freq_table.items(), key=lambda x: x[1])
        while len(sorted_chars) > 1:
            left, right = sorted_chars.pop(0), sorted_chars.pop(0)
            for c in left[0]: table[c] = "0" + table[c]
            for c in right[0]: table[c] = "1" + table[c]
            total_freq = left[1] + right[1]
            new_index = next((i for i, x in enumerate(sorted_chars) if x[1] >= total_freq), len(sorted_chars))
            sorted_chars.insert(new_index, (left[0] + right[0], total_freq))
        self.encoding_table = table
        self.decoding_table = { v: k for k, v in table.items() }
    
    def encode(self, message): return ''.join(self.encoding_table[c] for c in message)
    def decode(self, code):
        message = ""
        code_word = ""
        for c in code:
            code_word += c
            if code_word in self.decoding_table:
                message += self.decoding_table[code_word]
                code_word = ""
        return message

huffman = Huffman(PROBS)

## Base version
counts symbols by blocks

In [4]:
class ANS:
    def __init__(self, labeling: str | list):
        self.labeling = labeling
        self.block_size = len(labeling)

        # `symbol` appears in each block `count_per_block[symbol]` times
        # `count_before_index[i]` is the number of numbers labeled `labeling[i]` that are less than `i`
        # `symbol_table[symbol][i]` is the index in `labeling` of the `i`th number labeled `symbol`
        self.count_per_block = {}
        self.count_before_index = []
        self.symbol_table = {}
        for i, c in enumerate(labeling):
            if c not in self.symbol_table:
                self.count_per_block[c] = 0
                self.symbol_table[c] = []
            self.count_before_index.append(self.count_per_block[c])
            self.count_per_block[c] += 1
            self.symbol_table[c].append(i)

        # if the initial value of `X` is too small, then `C(X)` might equal `X`. but we want `C(X)` to be different from `X` for everything to be reversible
        # this is analagous to "appending a non-zero digit to the left of `X`" except in ANS it's "non-first-symbol symbol" instead
        # if you're sure that no message will ever start with the first symbol, you can set `initial_state = 0`
        self.initial_state = next((i for i, c in enumerate(labeling) if c != labeling[0]), len(labeling))
    
    def C(self, state: int, symbol: str):
        """Returns the `state + 1`th number labeled `symbol`"""

        # full_blocks * count_per_block[symbol] + symbols_left = state + 1
        full_blocks = (state + 1) // self.count_per_block[symbol]
        symbols_left = (state + 1) % self.count_per_block[symbol] # equivalently, (state + 1) - full_blocks * self.count_per_block[symbol]
        if symbols_left == 0:
            full_blocks -= 1
            symbols_left = self.count_per_block[symbol]

        # Count `symbols_left` symbols within the block
        index_within_block = self.symbol_table[symbol][symbols_left - 1]
        
        # if `block_size` is a power of 2, this multiplication is a bitshift
        return full_blocks * self.block_size + index_within_block
    
    def D(self, state: int):
        """Counts the number of numbers labeled `symbol` that are less than `state`"""

        index_within_block = state % self.block_size # if `block_size` is a power of 2, this is `state & (block_size - 1)`
        symbol = self.labeling[index_within_block]

        num_previous_blocks = state // self.block_size # if `block_size` is a power of 2, this division is a bitshift
        count_before_block = self.count_per_block[symbol] * num_previous_blocks
        
        return symbol, count_before_block + self.count_before_index[index_within_block]
    
    def encode(self, message: str | list):
        state = self.initial_state
        for symbol in message[::-1]:
            state = self.C(state, symbol)
        return state
    
    def decode(self, state: int):
        message = ""
        while state > self.initial_state:
            symbol, state = self.D(state)
            message += symbol
        return message, state

In [5]:
ans = ANS(BLOCK_LABELING)

message = "".join(choices(SYMBOLS, WEIGHTS, k = 40000))
information_content = -sum(log2(PROBS[symbol]) for symbol in message)

state = ans.encode(message)

ans_length = state.bit_length() - 1 # the first bit is always 1, so it's not counted
huffman_length = len(huffman.encode(message))

print(f"information content: {information_content}")
print(f"code length: {ans_length} ({ans_length / information_content - 1:.3%} longer than information)")
print(f"huffman: {huffman_length} ({huffman_length / information_content - 1:.3%} longer than information)")

decoded, state = ans.decode(state)
assert decoded == message, message + ", " + decoded

information content: 62387.958329509194
code length: 62409 (0.034% longer than information)
huffman: 63556 (1.872% longer than information)


## Streamed version
incorporates renormalization

`state` will now always be in `[block_size, block_size * 2 - 1]`. This means we have to initialize `state` to a number in that range, but the decoder doesn't need to know `initial_state` anymore, so you could store some data in `initial_state` that the decoder will reproduce at the end

With streaming, we take a hit on compression rate

In [34]:
class StreamANS:
    def __init__(self, labeling: str | list):
        self.labeling = labeling
        self.block_size = len(labeling)

        self.count_per_block = {}
        self.count_before_index = []
        self.symbol_table = {}
        for i, c in enumerate(labeling):
            if c not in self.symbol_table:
                self.count_per_block[c] = 0
                self.symbol_table[c] = []
            self.count_before_index.append(self.count_per_block[c])
            self.count_per_block[c] += 1
            self.symbol_table[c].append(i)
    
    def C(self, state: int, symbol: str):
        """Returns the `state + 1`th number labeled `symbol`"""

        full_blocks = (state + 1) // self.count_per_block[symbol]
        symbols_left = (state + 1) % self.count_per_block[symbol]
        if symbols_left == 0:
            full_blocks -= 1
            symbols_left = self.count_per_block[symbol]

        # Count `symbols_left` symbols within the block
        index_within_block = self.symbol_table[symbol][symbols_left - 1]
        
        return full_blocks * self.block_size + index_within_block
    
    def D(self, state: int):
        """Counts the number of numbers labeled `symbol` that are less than `state`"""

        # because of renormalization, `state` is guaranteed to be in [block_size, 2 * block_size - 1]
        index_within_block = state - self.block_size
        symbol = self.labeling[index_within_block]
        
        return symbol, self.count_per_block[symbol] + self.count_before_index[index_within_block]
    
    def encode(self, message: str | list, initial_state: int = 0):
        bitstream = 1
        state = [i for i in range(self.block_size) if self.C(i, message[-1]) >= self.block_size][initial_state] # the `initial_state`th number X such that `block_size <= C(X) < 2 * block_size`

        for symbol in message[::-1]:

            # shrink `X` until `C(X) < 2N`
            predicted = self.C(state, symbol)
            normalized_state = state
            while predicted >= 2 * self.block_size:
                bitstream <<= 1
                bitstream |= normalized_state & 1
                normalized_state >>= 1
                predicted = self.C(normalized_state, symbol)
            
            state = predicted
        
        return state, bitstream
    
    def decode(self, state: int, bitstream):
        message = ""
        
        while True:
            symbol, state = self.D(state)
            message += symbol

            # expand `X` until `X >= N` or we run out of bits
            while state < self.block_size and bitstream > 1:
                state <<= 1
                state |= bitstream & 1
                bitstream >>= 1
            if state < self.block_size: break
        
        return message, state

In [43]:
ans = StreamANS(BLOCK_LABELING)

message = "".join(choices(SYMBOLS, WEIGHTS, k = 120000))
information_content = -sum(log2(PROBS[symbol]) for symbol in message)

state, bitstring = ans.encode(message)

ans_length = len(BLOCK_LABELING).bit_length() + bitstring.bit_length() - 1
huffman_length = len(huffman.encode(message))

print(f"information content: {information_content}")
print(f"code length: {ans_length} ({ans_length / information_content - 1:.3%} longer than information)")
print(f"huffman: {huffman_length} ({huffman_length / information_content - 1:.3%} longer than information)")

decoded, state = ans.decode(state, bitstring)
assert decoded == message, message + ", " + decoded

information content: 117591.35184296229
code length: 119466 (1.594% longer than information)
huffman: 239401 (103.587% longer than information)


## Tabled version
`decode` and `encode` don't use `C` or `D` anymore; they only reference `renormalization_table`, `encoding_table`, and `decoding_table`, which are generated at initialization. The table generation I'm doing is slow; you should probably look at other implementations to see how they do it

Because we're storing the number of bits to renormalize with in the table, this only works when `block_size` is a power of 2(?); otherwise, encoding the same symbol into the same state might not pop the same number of bits to the bitstream. Duda mentions this in his paper; I'm not exactly sure how he does it

In [52]:
class tANS:
    def __init__(self, labeling: str | list):
        self.labeling = labeling
        self.block_size = len(labeling)
        self.block_mask = self.block_size - 1 # X & block_mask = X % block_size

        # these are now only used during initialization to generate tables
        self.count_per_block = {}
        self.count_before_index = []
        self.symbol_table = {}
        for i, c in enumerate(labeling):
            if c not in self.symbol_table:
                self.count_per_block[c] = 0
                self.symbol_table[c] = []
            self.count_before_index.append(self.count_per_block[c])
            self.count_per_block[c] += 1
            self.symbol_table[c].append(i)

        self.generate_tables()
    
    # `C` and `D` are only used during initialization
    def C(self, state: int, symbol: str):
        """Returns the `state + 1`th number labeled `symbol`"""

        full_blocks = (state + 1) // self.count_per_block[symbol]
        symbols_left = (state + 1) % self.count_per_block[symbol]
        if symbols_left == 0:
            full_blocks -= 1
            symbols_left = self.count_per_block[symbol]

        index_within_block = self.symbol_table[symbol][symbols_left - 1]
        
        return full_blocks * self.block_size + index_within_block
    
    def D(self, state: int):
        """Counts the number of numbers labeled `symbol` that are less than `state`"""

        index_within_block = state & self.block_mask
        symbol = self.labeling[index_within_block]
        count_within_block = sum(c == symbol for c in self.labeling[:index_within_block])
        
        return symbol, self.count_per_block[symbol] + count_within_block
    
    # the way i'm generating these is bad; i just wrote it like this to make it more obvious what it's doing
    def generate_tables(self):
        # i'm not exactly sure about the proof, but encoding any one symbol `symbol` will always result in either `n_bits_chopped` or `n_bits_chopped + 1` being output into the bitstream during renormalization
        # it's `n_bits_chopped + 1` when the state is over or equal to `min_state`. except here i use `min_state = 0` as a special case where the number of bits output is never `n_bits_chopped + 1`
        renormalization_table = {}
        for symbol in self.symbol_table.keys():
            n_bits_chopped = -1
            min_state = 0
            prev_n_bits = -1
            for state in range(self.block_size, self.block_size * 2):
                renormalized = state
                n_bits = 0
                while self.C(renormalized, symbol) >= self.block_size * 2:
                    renormalized >>= 1
                    n_bits += 1
                if n_bits != prev_n_bits:
                    prev_n_bits = n_bits
                    if n_bits_chopped < 0:
                        n_bits_chopped = n_bits
                    else:
                        assert min_state == 0, "uh oh"
                        min_state = state
            renormalization_table[symbol] = (n_bits_chopped, min_state, (1 << n_bits_chopped) - 1)
        self.renormalization_table = renormalization_table

        encoding_table = { c: {} for c in self.symbol_table.keys() }
        decoding_table = [0 for i in range(self.block_size)]
        for state in range(self.block_size, self.block_size * 2):
            symbol, old_state = self.D(state)
            encoding_table[symbol][old_state] = state
            n_bits = 0
            scaled_state = old_state
            while scaled_state < self.block_size:
                scaled_state <<= 1
                n_bits += 1
            decoding_table[state - self.block_size] = (old_state << n_bits, n_bits, (1 << n_bits) - 1)
        self.encoding_table = encoding_table
        self.decoding_table = decoding_table
    
    def encode(self, message: str | list, initial_state: int = 0):
        bitstream = 1
        state = [i for i, c in enumerate(self.labeling) if c == message[-1]][initial_state] # the `initial_state`th number labeled `message[-1]`
        assert self.decoding_table[state][1] > 0, "code is ambiguous because initial state is too large" # we want this symbol to prompt the decoder to scale back up so that it can run out of bits and terminate

        state += self.block_size
        for symbol in message[-2::-1]:
            n_bits, min_state, mask = self.renormalization_table[symbol]
            if min_state > 0 and state >= min_state:
                n_bits += 1
                mask = (mask << 1) | 1
            bitstream = (bitstream << n_bits) | (state & mask)
            state = self.encoding_table[symbol][state >> n_bits]
        
        return state, bitstream
    
    def decode(self, state: int, bitstream):
        message = ""
        
        while True:
            state -= self.block_size
            message += self.labeling[state]
            (state, num_bits, mask) = self.decoding_table[state]
            if num_bits > 0 and bitstream <= 1: break
            state |= bitstream & mask
            bitstream >>= num_bits
        
        return message, state

In [53]:
ans = tANS(BLOCK_LABELING)

message = "".join(choices(SYMBOLS, WEIGHTS, k = 150000))
information_content = -sum(log2(PROBS[symbol]) for symbol in message)

state, bitstring = ans.encode(message)

ans_length = len(BLOCK_LABELING).bit_length() + bitstring.bit_length() - 1
huffman_length = len(huffman.encode(message))

print(f"information content: {information_content}")
print(f"code length: {ans_length} ({ans_length / information_content - 1:.3%} longer than information)")
print(f"huffman: {huffman_length} ({huffman_length / information_content - 1:.3%} longer than information)")

decoded, state = ans.decode(state, bitstring)
assert decoded == message, message + ", " + decoded

information content: 146600.72253644982
code length: 148955 (1.606% longer than information)
huffman: 299259 (104.132% longer than information)
