# Explore ANS using toy examples

Consider the decimal system:
- Alphabet: $X = \{0, \dots, 9\}$, with symbols i.i.d. and uniformly distributed, so $P(X_i = x) = \tfrac{1}{10}$

The entropy per symbol: $$H_p(X_i) = E_p[-\log_2 P(X_i)] = -\log_2 \tfrac{1}{10} = \log_2 10 \approx 3.32 \text{ bits}$$

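The expected codeword length of an optimal symbol code, i.e. a Huffman code (the average leaf depth of its code tree over alphabet $X$). Six digits get 3-bit codewords and four get 4-bit codewords, so: $$E_p[l(X_i)] = \dfrac{6 \cdot 3 + 4 \cdot 4}{10} = 3.4 \text{ bits}$$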
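As a quick check of the two numbers above, the sketch below builds a Huffman code for ten equiprobable digits using Python's standard `heapq` module and compares its average codeword length with the entropy. It assumes Huffman coding as the optimal symbol code; `huffman_code_lengths` is a hypothetical helper introduced only for this illustration.

```python
import heapq
import math
from fractions import Fraction

def huffman_code_lengths(probs):
    """Hypothetical helper: codeword length of each symbol under a Huffman code."""
    # Heap entries: (subtree probability, unique tie-breaker, symbols in the subtree)
    heap = [(p, i, [i]) for i, p in enumerate(probs)]
    heapq.heapify(heap)
    lengths = [0] * len(probs)
    counter = len(probs)
    while len(heap) > 1:
        p1, _, syms1 = heapq.heappop(heap)
        p2, _, syms2 = heapq.heappop(heap)
        # Merging two subtrees pushes every symbol in them one level deeper
        for s in syms1 + syms2:
            lengths[s] += 1
        heapq.heappush(heap, (p1 + p2, counter, syms1 + syms2))
        counter += 1
    return lengths

probs = [Fraction(1, 10)] * 10
lengths = huffman_code_lengths(probs)
expected_length = sum(p * l for p, l in zip(probs, lengths))
entropy = -sum(float(p) * math.log2(float(p)) for p in probs)

print(sorted(lengths))         # [3, 3, 3, 3, 3, 3, 4, 4, 4, 4]
print(float(expected_length))  # 3.4
print(round(entropy, 2))       # 3.32
```

Exact `Fraction` arithmetic keeps the average length at exactly 17/5 = 3.4 bits, about 0.08 bits above the entropy: a symbol code must spend a whole number of bits on every symbol.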

Can we compress better than the optimal symbol code, i.e. get closer to the 3.32-bit entropy?


```python
from typing import Iterator

def encode_better_than_symbol_coding(msg: list[int], base: int) -> int:
    # Represent the whole message as one integer. Start from 1 (a sentinel digit)
    # so that leading zero symbols are not lost.
    compressed = 1
    for symb in msg:
        assert 0 <= symb < base, f"Symbol {symb} must be in [0, {base})"
        compressed = compressed * base + symb  # shift by one base-`base` digit and append the symbol
    return compressed

def decode_better_than_symbol_coding(compressed: int, base: int) -> Iterator[int]:
    # Pops digits from the end, so symbols come out in reverse order (LIFO)
    while compressed != 1:  # stop once only the sentinel 1 is left
        yield compressed % base  # last digit = most recently encoded symbol
        compressed //= base      # remove the last digit


initial_msg, base = [2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5], 10
e = encode_better_than_symbol_coding(initial_msg, base)
d = decode_better_than_symbol_coding(e, base)

print(f"initial_msg: {initial_msg}")
print(f"encoded: {e} | binary representation: {e:b}")
print(f"decoded: {list(d)[::-1]}")  # reverse to restore the original order

# Bits needed to store the integer, averaged over the symbols
print(f"bitrate: {e.bit_length() / len(initial_msg):.2f}")
```

```text
initial_msg: [2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5]
encoded: 123456789234567892345 | binary representation: 1101011000101001110100111111011010011101100101100000100110101111001
decoded: [2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5]
bitrate: 3.35
```


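**Observations on the compression algorithm above** (it is simply a conversion from base 10 to base 2):
1. It operates as a stack (LIFO): the decoder returns the symbols in reverse order.
2. It amortizes bits over the whole message instead of spending a whole number of bits per symbol, so it can beat symbol codes: 3.35 bits per symbol here versus 3.4 for the Huffman code, approaching the 3.32-bit entropy as the message grows and the sentinel's cost becomes negligible.
3. It compresses a sequence of symbols optimally only if these symbols:
    - come from the same alphabet
    - are i.i.d. and uniformly distributed over that alphabet (no correlations between symbols)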
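To see point 2 numerically, the sketch below re-implements the encoder as a hypothetical `encode_digits` helper and measures bits per symbol for longer messages. It assumes the all-9s message as a deterministic worst case (the largest digit at every position); the bitrate then falls toward $\log_2 10 \approx 3.32$ as the constant overhead of the leading sentinel is spread over more symbols.

```python
import math

def encode_digits(msg, base=10):
    # Same scheme as above: start from the sentinel 1 and append each symbol as a digit
    compressed = 1
    for symb in msg:
        compressed = compressed * base + symb
    return compressed

for n in (20, 100, 1000):
    msg = [9] * n  # worst case: the largest digit at every position
    bits = encode_digits(msg).bit_length()
    print(n, bits, round(bits / n, 3))
print(round(math.log2(10), 3))

# Output:
# 20 68 3.4
# 100 334 3.34
# 1000 3323 3.323
# 3.322
```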


### Improving our coding algorithm

```python
import random


class UniformCoder:
    """Encodes each symbol as one digit of a mixed-radix number (one base per symbol)."""

    def __init__(self):
        self.compressed = 1  # sentinel, as in the single-base version

    def encode(self, symbol: int, base: int) -> None:
        assert 0 <= symbol < base, f"Symbol {symbol} must be in [0, {base})"
        self.compressed = self.compressed * base + symbol  # append one base-`base` digit

    def decode(self, base: int) -> int:
        symbol = self.compressed % base  # get the last digit
        self.compressed //= base         # remove the last digit
        return symbol


# Initial message
initial_msg = [random.randint(0, 9) for _ in range(5)]
coder = UniformCoder()
print(f"Initial message: {initial_msg}")

# Encode the message with a different base per symbol
print("\nEncoding process:")
# Base must be at least symbol + 1. Toy shortcut: the bases are derived from the
# symbols themselves, so a real decoder would need them as side information.
bases = [symbol + 1 for symbol in initial_msg]

for symbol, base in zip(initial_msg, bases):
    coder.encode(symbol, base)
    print(f"Symbol: {symbol}, Base: {base} -> Compressed: {coder.compressed}")

print(f"\nFinal encoded value: {coder.compressed}")
print(f"Binary representation: {coder.compressed:b}")
n_bits = coder.compressed.bit_length()

# Decode the message: pop digits in reverse order, using the same bases
print("\nDecoding process:")
decoded = []
for base in reversed(bases):
    symbol = coder.decode(base)
    decoded.append(symbol)
    print(f"Base: {base} -> Symbol: {symbol}")
print(f"\nDecoded message: {decoded[::-1]}")
print(f"Bitrate: {n_bits / len(initial_msg):.2f} bits per symbol")
```

```text
Initial message: [4, 2, 2, 1, 2]

Encoding process:
Symbol: 4, Base: 5 -> Compressed: 9
Symbol: 2, Base: 3 -> Compressed: 29
Symbol: 2, Base: 3 -> Compressed: 89
Symbol: 1, Base: 2 -> Compressed: 179
Symbol: 2, Base: 3 -> Compressed: 539

Final encoded value: 539
Binary representation: 1000011011

Decoding process:
Base: 3 -> Symbol: 2
Base: 2 -> Symbol: 1
Base: 3 -> Symbol: 2
Base: 3 -> Symbol: 2
Base: 5 -> Symbol: 4

Decoded message: [4, 2, 2, 1, 2]
Bitrate: 2.00 bits per symbol
```


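Observations:
- You do NOT need to encode every symbol with the same alphabet: each symbol can use its own base, and smaller bases lower the bitrate (2.00 bits per symbol in the run above).
    - Caveat: the decoder must know each base before it decodes the corresponding symbol. Here the bases are derived from the symbols themselves (`symbol + 1`), which a real decoder could not know, so the 2.00 bits per symbol does not include the cost of transmitting the bases.
- Within each base, the symbol still has to be uniformly distributed over its own alphabet $\{0, \dots, \text{base} - 1\}$ for the encoding to be optimal.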
```python
import random


class ANSCoder:
    """Toy ANS coder: each step composes a uniform *decode* with base
    scaled_prob[symbol] and a uniform *encode* with base n = 2**precision."""

    def __init__(self, precision: int):
        self.precision = precision
        self.n = 1 << precision           # scaled probabilities should sum to n
        self.mask = (1 << precision) - 1  # x & mask == x % n
        self.compressed = 1               # same starting value as the uniform coder

    def encode(self, symbol: int, scaled_prob: list[int]) -> int:
        p = scaled_prob[symbol]
        # Step 1: "decode" a digit z in base p from the compressed state.
        # If p == 2**k, then x % p == x & (p - 1) and x // p == x >> k.
        if (p & (p - 1)) == 0:
            z = self.compressed & (p - 1)
            self.compressed >>= p.bit_length() - 1
        else:
            z = self.compressed % p
            self.compressed //= p

        # Step 2: shift z into the slot range [cum, cum + p) reserved for `symbol`.
        z += sum(scaled_prob[:symbol])

        # Step 3: "encode" z as a digit in base n.
        # Equivalent to self.compressed * self.n + z, but as a shift.
        self.compressed = (self.compressed << self.precision) + z
        return z

    def decode(self, scaled_prob: list[int]) -> int:
        # Step 1: "decode" a digit z in base n (equivalent to % n and // n).
        z = self.compressed & self.mask
        self.compressed >>= self.precision

        # Step 2: find the symbol whose slot range contains z.
        for symbol, p in enumerate(scaled_prob):
            if z < p:
                break
            z -= p
        else:
            raise ValueError("decoded slot lies outside the range covered by scaled_prob")

        # Step 3: "encode" the offset z back as a digit in base scaled_prob[symbol].
        self.compressed = self.compressed * scaled_prob[symbol] + z
        return symbol


# Generate a random message of 5 digits
initial_msg = [random.randint(0, 9) for _ in range(5)]
print(f"Initial message: {initial_msg}")

# Initialize the ANS coder with n = 2**8 = 256 slots
precision = 8
ans_coder = ANSCoder(precision)

# Every digit gets the same scaled probability 16, i.e. 16/256 = 1/16.
# Note: these sum to 160, not 256, so 96 of the 256 slots are never used and in the
# long run each digit costs log2(256/16) = 4 bits instead of log2(10) ≈ 3.32 bits.
scaled_probs = [2**4] * 10

# Encode the message (each value is the base-256 "digit" z written by encode)
encoded_values = []
for symbol in initial_msg:
    encoded_values.append(ans_coder.encode(symbol, scaled_probs))
print(f"Encoded values: {encoded_values}")

print(f"\nFinal encoded value: {ans_coder.compressed}")
print(f"Binary representation: {ans_coder.compressed:b}")
n_bits = ans_coder.compressed.bit_length()

# Decode the message (stack: symbols come out in reverse order)
decoded_msg = [ans_coder.decode(scaled_probs) for _ in range(len(initial_msg))]
print(f"Decoded message: {decoded_msg[::-1]}")

# Verify correctness
print(f"Success: {decoded_msg[::-1] == initial_msg}")
print(f"Bitrate: {n_bits / len(initial_msg):.2f} bits per symbol")
```


```text
Initial message: [0, 7, 0, 5, 6]
Encoded values: [1, 113, 1, 81, 97]

Final encoded value: 460129
Binary representation: 1110000010101100001
Decoded message: [0, 7, 0, 5, 6]
Success: True
Bitrate: 3.80 bits per symbol
```
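The run above gives every digit a scaled probability of 16 out of $2^8 = 256$. These sum to only 160, so 96 of the 256 slots are never used and the bitrate stays above the 3.32-bit entropy. The sketch below is a minimal functional re-implementation of the same encode/decode steps (the names `encode`, `decode` and `bits_per_symbol` are hypothetical), assuming scaled probabilities that sum exactly to $2^{\text{precision}}$. For the digits it assumes the quantization `[26] * 6 + [25] * 4`. It also tries an assumed skewed 4-symbol source with probabilities 1/2, 1/4, 1/8, 1/8 (entropy 1.75 bits) to show that ANS charges each symbol roughly $\log_2(n / p)$ bits, so frequent symbols cost less.

```python
import math
import random

def encode(msg, probs, precision):
    """Encode `msg` into one integer (stack semantics, like ANSCoder)."""
    assert sum(probs) == 1 << precision, "scaled probabilities must sum to 2**precision"
    x = 1
    for s in msg:
        p = probs[s]
        z = x % p + sum(probs[:s])       # "decode" a base-p digit, move it into s's slot range
        x = ((x // p) << precision) + z  # "encode" z as a base-2**precision digit
    return x

def decode(x, probs, precision, count):
    mask = (1 << precision) - 1
    out = []
    for _ in range(count):
        z = x & mask                      # "decode" a base-2**precision digit
        x >>= precision
        s = 0
        while z >= probs[s]:              # find the symbol whose slot range contains z
            z -= probs[s]
            s += 1
        x = x * probs[s] + z              # "encode" the offset back as a base-probs[s] digit
        out.append(s)
    assert x == 1, "stream not fully consumed"
    return out[::-1]                      # stack: symbols come out in reverse order

def bits_per_symbol(msg, probs, precision):
    x = encode(msg, probs, precision)
    assert decode(x, probs, precision, len(msg)) == msg
    return x.bit_length() / len(msg)

rng = random.Random(0)

# 1) Uniform digits 0-9 with scaled probabilities that sum to 2**8 = 256
uniform = [26] * 6 + [25] * 4
digits = [rng.randrange(10) for _ in range(2000)]
print(f"uniform digits: {bits_per_symbol(digits, uniform, 8):.3f} bits/symbol "
      f"(entropy {math.log2(10):.3f})")

# 2) Assumed skewed source over 4 symbols: P = 1/2, 1/4, 1/8, 1/8
skewed = [8, 4, 2, 2]                     # sums to 2**4 = 16
msg = rng.choices(range(4), weights=skewed, k=2000)
print(f"skewed source:  {bits_per_symbol(msg, skewed, 4):.3f} bits/symbol "
      f"(entropy 1.750)")
```

With the probabilities summing to 256, the uniform-digit bitrate should land close to $\log_2 10$, and the skewed source should cost close to 1.75 bits per symbol, below the 2 bits that a uniform code over 4 symbols would need.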
