# Explore ANS using toy examples

Given the Decimal System:
- Alphabet: $X = \{0, \dots, 9\}$, uniformly distributed

The entropy per symbol: $$H_p(X_i) = E_p[-\log_2 P(X_i)] = -\log_2 \dfrac{1}{10} = \log_2 10 \approx 3.32 \text{ bits}$$

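The expected code word length using an optimal symbol code (the average depth of the code tree for alphabet $X$; for example, a Huffman tree for ten equiprobable symbols has six leaves at depth 3 and four at depth 4): $$E_p[l(X_i)] = \dfrac{6 \cdot 3 + 4 \cdot 4}{10} = 3.4 \text{ bits}$$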
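
The two numbers above can be checked with a short sketch. It assumes a Huffman code as the optimal symbol code; the helper name `huffman_code_lengths` is made up for this illustration, and integer weights are used instead of probabilities to avoid floating-point ties.

```python
import heapq
import math

def huffman_code_lengths(weights):
    """Return the code word length of each symbol in a Huffman code."""
    # Heap entries: (subtree weight, tie breaker, symbols in the subtree)
    heap = [(w, i, [i]) for i, w in enumerate(weights)]
    heapq.heapify(heap)
    lengths = [0] * len(weights)
    tie_breaker = len(weights)
    while len(heap) > 1:
        w1, _, s1 = heapq.heappop(heap)
        w2, _, s2 = heapq.heappop(heap)
        # Merging two subtrees pushes all of their symbols one level deeper
        for s in s1 + s2:
            lengths[s] += 1
        heapq.heappush(heap, (w1 + w2, tie_breaker, s1 + s2))
        tie_breaker += 1
    return lengths

weights = [1] * 10  # ten equally likely digits
lengths = huffman_code_lengths(weights)
expected_length = sum(w * l for w, l in zip(weights, lengths)) / sum(weights)
entropy = math.log2(len(weights))

print(f"code word lengths: {sorted(lengths)}")
print(f"expected length:   {expected_length:.2f} bits")
print(f"entropy:           {entropy:.2f} bits")
```

The gap between the two values is what a symbol code cannot recover, because every symbol must be assigned a whole number of bits.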

Can we achieve better compression than the optimal symbol code?


In [1]:
from typing import Iterator

def encode_better_than_symbol_coding(msg:list[int], base:int)->int:
    # integer that represents the compressed message; starting at 1 (not 0) preserves leading zero symbols
    compressed = 1
    for symb in msg:
        assert 0 <= symb < base, f"Symbol {symb} is outside the range [0, {base})"
        compressed = compressed * base + symb # multiply by base and add the new symbol
    return compressed

def decode_better_than_symbol_coding(compressed:int, base:int)->Iterator[int]:
    while compressed != 1: # while the compressed number is not 1 (starting number)
        yield compressed % base # get the last digit
        compressed = compressed // base # remove the last digit


initial_msg, base = [2,3,4,5,6,7,8,9,2,3,4,5,6,7,8,9,2,3,4,5], 10
e = encode_better_than_symbol_coding(initial_msg, base)
d = decode_better_than_symbol_coding(e, base)

print(f"initial_msg: {initial_msg}")
print(f"encoded: {e} | binary representation: {e:b}")
print(f"decoded: {list(d)[::-1]}")


print(f"bitrate: {e.bit_length() / len(initial_msg) :.2f}")

initial_msg: [2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5]
encoded: 123456789234567892345 | binary representation: 1101011000101001110100111111011010011101100101100000100110101111001
decoded: [2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5]
bitrate: 3.35


**Observations on the above compression algorithm** (a conversion between numeral systems):
1. It operates as a stack (LIFO): the last symbol encoded is the first one decoded.
2. It amortizes compressed bits over symbols, so its bitrate can fall below that of a symbol code (here 3.35 vs. 3.4 bits per symbol).
3. It optimally compresses a sequence of symbols if these symbols:
    - are from the same alphabet
    - are uniformly distributed over the alphabet and statistically independent, i.e. i.i.d. (no correlations between symbols)
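
Observations 2 and 3 can be made concrete with a small experiment. The sketch below re-implements the encoder under the hypothetical name `encode_digits` and uses an arbitrary fixed seed; it encodes random digit strings of growing length and compares the bits per symbol with $\log_2 10$. Because the integer starts at 1, the overhead should stay within $2/n$ bits per symbol for a message of $n$ digits.

```python
import math
import random

def encode_digits(msg, base):
    """Same base-conversion idea as above: fold all symbols into one integer."""
    compressed = 1
    for symb in msg:
        compressed = compressed * base + symb
    return compressed

random.seed(0)  # arbitrary seed, only for reproducibility
entropy = math.log2(10)
rates = {}
for n in (10, 100, 1000, 10000):
    msg = [random.randrange(10) for _ in range(n)]
    rates[n] = encode_digits(msg, 10).bit_length() / n
    print(f"n={n:>5}: {rates[n]:.4f} bits/symbol (entropy {entropy:.4f})")
```

The constant overhead of the leading 1 is spread over all symbols, which is why longer messages get closer to the entropy.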


### Improving our coding algorithm

In [2]:
from typing import Iterator

class UniformCoder:
    def __init__(self):
        self.compressed = 1

    def encode(self, symbol:int, base:int)->None:
        # push the symbol onto the compressed integer using the given base
        assert 0 <= symbol < base, f"Symbol {symbol} is outside the range [0, {base})"
        self.compressed = self.compressed * base + symbol

    def decode(self, base:int)->int:
        symbol = self.compressed % base # get the last digit
        self.compressed //= base # remove the last digit
        return symbol
    


# initial message
import random
initial_msg = [random.randint(0, 9) for _ in range(5)]
coder = UniformCoder()
print(f"Initial message: {initial_msg}")

# Encode the message with different bases
print("\nEncoding process:")
# Base must be at least symbol + 1
bases = [symbol + 1 for symbol in initial_msg]

for symbol, base in zip(initial_msg, bases):
    coder.encode(symbol, base)
    print(f"Symbol: {symbol}, Base: {base} -> Compressed: {coder.compressed}")

print(f"\nFinal encoded value: {coder.compressed}")
print(f"Binary representation: {coder.compressed:b}")
e = coder.compressed.bit_length()

# Decode the message
print("\nDecoding process:")
decoded = []

for base in reversed(bases):
    symbol = coder.decode(base)
    decoded.append(symbol)
    print(f"Base: {base} -> Symbol: {symbol}")
print(f"\nDecoded message: {decoded[::-1]}")
print(f"Bitrate: {e / len(initial_msg) :.2f} bits per symbol")

Initial message: [5, 1, 3, 9, 4]

Encoding process:
Symbol: 5, Base: 6 -> Compressed: 11
Symbol: 1, Base: 2 -> Compressed: 23
Symbol: 3, Base: 4 -> Compressed: 95
Symbol: 9, Base: 10 -> Compressed: 959
Symbol: 4, Base: 5 -> Compressed: 4799

Final encoded value: 4799
Binary representation: 1001010111111

Decoding process:
Base: 5 -> Symbol: 4
Base: 10 -> Symbol: 9
Base: 4 -> Symbol: 3
Base: 2 -> Symbol: 1
Base: 6 -> Symbol: 5

Decoded message: [5, 1, 3, 9, 4]
Bitrate: 2.60 bits per symbol


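Observations:
- You do NOT need to encode every symbol with the same alphabet: the base can change from symbol to symbol.
    - A symbol encoded with base $b$ costs about $\log_2 b$ bits, so smaller bases lower the bitrate (2.60 bits per symbol here).
    - The decoder must apply the same bases in reverse order. In this toy example the bases are derived from the symbols themselves (`symbol + 1`), so a real decoder would need them as side information.
- For the result to be optimal, each symbol still has to be uniformly distributed over its own alphabet.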
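
A case where the bases are known to the decoder in advance is a time of day, with bases 24, 60 and 60. The sketch below is only an illustration: the function names `encode_mixed` and `decode_mixed` and the example time are assumptions, not part of the notebook.

```python
import math

def encode_mixed(symbols, bases):
    """Push each symbol onto one integer, using a different base per position."""
    compressed = 1
    for s, b in zip(symbols, bases):
        assert 0 <= s < b, f"Symbol {s} is outside the range [0, {b})"
        compressed = compressed * b + s
    return compressed

def decode_mixed(compressed, bases):
    """Pop the symbols again; the bases are applied in reverse (stack) order."""
    symbols = []
    for b in reversed(bases):
        symbols.append(compressed % b)
        compressed //= b
    return symbols[::-1]

bases = [24, 60, 60]      # hours, minutes, seconds
time_of_day = [13, 37, 5]  # hypothetical example value

c = encode_mixed(time_of_day, bases)
ideal_bits = sum(math.log2(b) for b in bases)

print(f"compressed: {c} ({c.bit_length()} bits incl. the leading 1)")
print(f"ideal cost: {ideal_bits:.2f} bits")
print(f"decoded:    {decode_mixed(c, bases)}")
```

Here the bases are fixed by the data format, so no side information is needed, and the cost is close to $\log_2 24 + 2\log_2 60$ bits.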

### ANS Coder

In [8]:
class ANSCoder:
    def __init__(self, precision:int):
        self.n = 2**precision
        self.precision = precision
        self.mask = (1 << precision) - 1
        # uniform coder
        self.compressed = 1


    def encode(self, symbol:int, scaled_prob:list[int])->int:
        # Step 1: uniform decode with base scaled_prob[symbol], i.e. z = compressed % base, compressed //= base
        # If scaled_prob[symbol] is a power of 2, bit operations can replace % and //
        # For example, if scaled_prob[symbol] = 2^k, then x % 2^k = x & (2^k - 1) and x // 2^k = x >> k
        if scaled_prob[symbol] & (scaled_prob[symbol] - 1) == 0:  # Check if power of 2
            # bitwise version
            z = self.compressed & (scaled_prob[symbol] - 1)
            self.compressed >>= scaled_prob[symbol].bit_length() - 1
        else:
            # original
            z = self.compressed % scaled_prob[symbol]
            self.compressed //= scaled_prob[symbol]
        

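        # Step 2: move z into the symbol's slot by adding the scaled probabilities of all preceding symbols
        for p in scaled_prob[:symbol]:
            z += p

        # Step 3: uniform encode of z with base self.n = 2^precision
        # Equivalent to the slower: self.compressed = self.compressed * self.n + z
        self.compressed = (self.compressed << self.precision) + z
        return z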



    def decode(self, scaled_prob:list[int])->int:
        # Step 1: uniform decode with base self.n = 2^precision
        # Equivalent to the slower: z = self.compressed % self.n
        z = self.compressed & self.mask
        # Equivalent to the slower: self.compressed //= self.n
        self.compressed >>= self.precision

        # Step 2: find the symbol whose slot contains z and reduce z to its offset inside that slot
        for i,p in enumerate(scaled_prob):
            if p > z:
                symbol = i
                break
            else:
                z -= p
        # Step 3: uniform encode of z with base scaled_prob[symbol]
        self.compressed = self.compressed * scaled_prob[symbol] + z
        return symbol
    


import random

# Generate random message
initial_msg = [random.randint(0, 9) for _ in range(5)]
print(f"Initial message: {initial_msg}")

# Initialize ANS coder
precision = 8
ans_coder = ANSCoder(precision)

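# Define scaled probabilities for digits 0-9 (equal probability)
# Note: these sum to 160, not 2**precision = 256, so each symbol costs log2(256/16) = 4 bits rather than log2(10) ≈ 3.32 bits
scaled_probs = [2**4] * 10  # Each digit gets the same scaled probability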

# Encode the message
encoded_values = []
for symbol in initial_msg:
    encoded = ans_coder.encode(symbol, scaled_probs)
    encoded_values.append(encoded)
print(f"Encoded values: {encoded_values}")

print(f"\nFinal encoded value: {ans_coder.compressed}")
print(f"Binary representation: {ans_coder.compressed:b}")
e = ans_coder.compressed.bit_length()

# Decode the message
decoded_msg = []
for _ in range(len(initial_msg)):
    decoded = ans_coder.decode(scaled_probs)
    decoded_msg.append(decoded)
print(f"Decoded message: {decoded_msg[::-1]}")

# Verify correctness
print(f"Success: {decoded_msg[::-1] == initial_msg}")
print(f"Bitrate: {e / len(initial_msg) :.2f} bits per symbol")

    


Initial message: [2, 3, 1, 6, 7]
Encoded values: [33, 49, 17, 97, 113]

Final encoded value: 2299505
Binary representation: 1000110001011001110001
Decoded message: [2, 3, 1, 6, 7]
Success: True
Bitrate: 4.40 bits per symbol
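
The scaled probabilities in the example above sum to 160, while the slots addressed by `z` span $2^{\text{precision}} = 256$ values, so part of the range is never used. One possible way to build integer probabilities that sum exactly to $2^{\text{precision}}$ is largest-remainder rounding. This is a sketch under assumptions: the helper `quantize_probs` is hypothetical, and the rounding scheme is a choice made here, not something the notebook prescribes.

```python
import math

def quantize_probs(probs, precision):
    """Scale probabilities to positive integers that sum to 2**precision."""
    total = 1 << precision
    exact = [p * total for p in probs]
    # Round down, but keep every symbol encodable (at least 1)
    scaled = [max(1, math.floor(x)) for x in exact]
    leftover = total - sum(scaled)
    assert leftover >= 0, "precision too small for this distribution"
    # Hand the leftover units to the symbols with the largest fractional parts
    order = sorted(range(len(probs)),
                   key=lambda i: exact[i] - math.floor(exact[i]),
                   reverse=True)
    for i in order[:leftover]:
        scaled[i] += 1
    return scaled

uniform_digits = quantize_probs([0.1] * 10, precision=8)
skewed = quantize_probs([0.5, 0.3, 0.2], precision=8)

print(f"uniform digits: {uniform_digits} (sum {sum(uniform_digits)})")
print(f"skewed:         {skewed} (sum {sum(skewed)})")
```

With such a table, a uniform digit would cost about $\log_2(256/26) \approx 3.30$ or $\log_2(256/25) \approx 3.36$ bits instead of 4.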


### ANS Coder with improved runtime 

In [22]:
from typing import Optional

class StreamANSCoder:
    def __init__(self, precision:int, compressed:Optional[list[int]]=None):
        self.precision = precision
        self.mask = (1 << precision) - 1

        if compressed is None:
            # fresh encoder: empty bulk, head starts at 1
            self.bulk = []
            self.head = 1
        else:
            # decoder: rebuild head from the last (up to) two words of the compressed data
            self.bulk = compressed.copy()
            self.head = 0

            if len(self.bulk) != 0:
                self.head = self.bulk.pop()

            if len(self.bulk) != 0:
                self.head = (self.head << self.precision) | self.bulk.pop()


    def encode(self, symbol:int, scaled_prob:list[int])->int:
        # if encoding would grow head beyond 2*precision bits, move its lowest precision bits to bulk first
        if (self.head >> self.precision) >= scaled_prob[symbol]:
            self.bulk.append(self.head & self.mask)
            self.head >>= self.precision

        z = self.head % scaled_prob[symbol]
        self.head //= scaled_prob[symbol]

        for p in scaled_prob[:symbol]:
            z += p

        self.head = (self.head << self.precision) | z
        return z
        
    def decode(self, scaled_prob:list[int])->int:

        z = self.head & self.mask
        self.head >>= self.precision

        for i,p in enumerate(scaled_prob):
            if p > z:
                symbol = i
                break
            else:
                z -= p
        
        self.head = self.head * scaled_prob[symbol] + z

        if (self.head >> self.precision) == 0 and len(self.bulk) != 0:
            self.head = (self.head << self.precision) | self.bulk.pop()

        return symbol
    
    def get_compressed(self)->list[int]:
        compressed, head = self.bulk.copy(), self.head

        while head != 0:
            compressed.append(head & self.mask)
            head >>= self.precision
        return compressed
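
To judge how well a streaming coder does, it helps to know the ideal cost of a message under the model, i.e. the sum of $-\log_2 P(\text{symbol})$. The sketch below computes it for the distribution and message used in the demonstration that follows; the helper name `information_content` is an assumption for this illustration.

```python
import math

precision = 8
total = 1 << precision              # 256
scaled_prob = [128, 77, 51]         # model used in the demonstration
message = [0, 1, 2, 0, 1, 0, 2, 1]  # message used in the demonstration

def information_content(msg, scaled_prob, total):
    """Ideal cost in bits of msg under the model: sum of -log2 P(symbol)."""
    return sum(math.log2(total / scaled_prob[s]) for s in msg)

# Entropy of the model: expected ideal cost per symbol
entropy = sum(p / total * math.log2(total / p) for p in scaled_prob)
ideal_bits = information_content(message, scaled_prob, total)

print(f"entropy of the model: {entropy:.3f} bits per symbol")
print(f"ideal message cost:   {ideal_bits:.3f} bits")
print(f"ideal bitrate:        {ideal_bits / len(message):.3f} bits per symbol")
```

For a message this short, the measured size will be noticeably larger than the ideal cost, since the output is rounded up to whole `precision`-bit words and includes the initial head value.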
        
    

In [25]:
# Example usage of StreamANSCoder
# Let's encode and decode a simple message

print("=" * 60)
print("STREAMING ANS ENCODER/DECODER DEMONSTRATION")
print("=" * 60)

# Define a simple probability distribution for symbols 0, 1, 2
# These should be scaled to integers that sum to 2^precision
precision = 8
total_prob = 1 << precision  # 256

# Example: symbol 0 has prob 0.5, symbol 1 has prob 0.3, symbol 2 has prob 0.2
scaled_prob = [128, 77, 51]  # These sum to 256
prob_percentages = [p/total_prob*100 for p in scaled_prob]

# Create encoder
encoder = StreamANSCoder(precision)

# Message to encode
message = [0, 1, 2, 0, 1, 0, 2, 1]

print(f"\n📝 Original message: {message}")
print(f"📊 Probability distribution:")
for i, (prob, pct) in enumerate(zip(scaled_prob, prob_percentages)):
    print(f"   Symbol {i}: {prob}/{total_prob} ({pct:.1f}%)")

print(f"\n🔧 Encoding (reverse order)...")
print("-" * 40)

# Encode the message in reverse order: ANS is a stack (LIFO), so the decoder then yields the symbols in their original order
for i, symbol in enumerate(reversed(message)):
    z = encoder.encode(symbol, scaled_prob)
    print(f"  Step {i+1}: Encoded symbol {symbol} → z = {z}")

# Get compressed representation
compressed = encoder.get_compressed()
print(f"\n📦 Compressed data: {compressed}")

# Calculate compression statistics
original_bits = len(message) * precision  # Baseline assumption: each uncompressed symbol is stored in precision bits
compressed_bits = len(compressed) * precision
bitrate = compressed_bits / len(message)
compression_ratio = original_bits / compressed_bits

print(f"\n📈 COMPRESSION STATISTICS")
print("-" * 40)
print(f"  Original size:     {original_bits:3d} bits ({len(message)} symbols × {precision} bits/symbol)")
print(f"  Compressed size:   {compressed_bits:3d} bits ({len(compressed)} values × {precision} bits/value)")
print(f"  Bitrate:           {bitrate:.2f} bits per symbol")
print(f"  Compression ratio: {compression_ratio:.2f}:1")

print(f"\n🔓 Decoding...")
print("-" * 40)

# Create decoder with the compressed data
decoder = StreamANSCoder(precision, compressed)

# Decode the message
decoded_message = []
for i in range(len(message)):
    symbol = decoder.decode(scaled_prob)
    decoded_message.append(symbol)
    print(f"  Step {i+1}: Decoded symbol {symbol}")

print(f"\n✅ RESULTS")
print("-" * 40)
print(f"  Original:  {message}")
print(f"  Decoded:   {decoded_message}")
print(f"  Match:     {'✓ SUCCESS' if message == decoded_message else '✗ FAILED'}")
print("=" * 60)


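============================================================
STREAMING ANS ENCODER/DECODER DEMONSTRATION
============================================================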

📝 Original message: [0, 1, 2, 0, 1, 0, 2, 1]
📊 Probability distribution:
   Symbol 0: 128/256 (50.0%)
   Symbol 1: 77/256 (30.1%)
   Symbol 2: 51/256 (19.9%)

🔧 Encoding (reverse order)...
----------------------------------------
  Step 1: Encoded symbol 1 → z = 129
  Step 2: Encoded symbol 2 → z = 232
  Step 3: Encoded symbol 0 → z = 104
  Step 4: Encoded symbol 1 → z = 203
  Step 5: Encoded symbol 0 → z = 75
  Step 6: Encoded symbol 2 → z = 213
  Step 7: Encoded symbol 1 → z = 151
  Step 8: Encoded symbol 0 → z = 23

📦 Compressed data: [213, 23, 5]

📈 COMPRESSION STATISTICS
----------------------------------------
  Original size:      64 bits (8 symbols × 8 bits/symbol)
  Compressed size:    24 bits (3 values × 8 bits/value)
  Bitrate:           3.00 bits per symbol
  Compression ratio: 2.67:1

🔓 Decoding...
----------------------------------------
  Step 1: Decoded symbol 0
  Step 2: Decoded symbol 1
  Step 3: Decoded symbol 2
  Step 4: 