# Understanding Asymmetric Numeral Systems (ANS, specifically rANS)

ANS is a family of entropy coders developed by Jarek Duda (2009-2014) that offers:
- Compression efficiency close to that of arithmetic coding
- High speed (often faster than arithmetic coding and comparable to Huffman coding, especially on modern CPUs)

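Core Concept:
- The whole message is represented by a single integer, the state
- Encoding a symbol s with probability p_s maps the current state x to a larger state, roughly x / p_s, so the state grows by about log2(1/p_s) bits
- Decoding reverses this mapping, recovering both the symbol and the previous state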
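A minimal sketch of the core idea, assuming a hypothetical three-symbol alphabet with integer frequencies summing to M = 8, and relying on Python's unbounded integers so that no renormalization is needed. It uses the rANS step C(s, x) = floor(x / f_s) * M + c_s + (x mod f_s), where c_s is the cumulative frequency of the symbols before s, and checks that the final state's bit length tracks the message's information content. The names `FREQS`, `encode`, and `decode` are illustrative only.

```python
import math

# Hypothetical toy model: symbol -> integer frequency; the frequencies sum to M.
FREQS = {"a": 5, "b": 2, "c": 1}
M = sum(FREQS.values())  # 8

# Cumulative frequency c_s: total frequency of all symbols listed before s.
CUM = {}
_total = 0
for _sym, _f in FREQS.items():
    CUM[_sym] = _total
    _total += _f


def encode(symbols, x=1):
    """Fold a whole message into one integer state (no renormalization)."""
    # ANS is LIFO, so push symbols in reverse to pop them in forward order.
    for s in reversed(symbols):
        f, c = FREQS[s], CUM[s]
        x = (x // f) * M + c + (x % f)
    return x


def decode(x, n):
    """Pop n symbols off the state; returns (symbols, final_state)."""
    out = []
    for _ in range(n):
        slot = x % M
        # The symbol whose range [c_s, c_s + f_s) contains the slot.
        s = next(t for t, c in CUM.items() if c <= slot < c + FREQS[t])
        x = FREQS[s] * (x // M) + slot - CUM[s]
        out.append(s)
    return out, x


msg = list("aaaaabbc") * 125
state = encode(msg)
decoded, initial = decode(state, len(msg))
ideal_bits = sum(-math.log2(FREQS[s] / M) for s in msg)
print(decoded == msg, initial, state.bit_length(), round(ideal_bits, 1))
```

Both loops run the same bijection in opposite directions; a practical coder keeps x within a fixed range by streaming bits in and out instead of letting it grow without bound.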

Variants:
1. tANS (tabled ANS)
   - Used for static probability distributions
   - Relies on precomputed tables

2. rANS (range ANS)
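   - Handles adaptive probabilities, since each step only needs the symbol's frequency and cumulative frequency
   - Better suited to per-step (dynamic) distributions such as Transformer outputs, because no table has to be rebuilt when the distribution changes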
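To drive rANS from a model's per-step output, the probabilities must first be quantized to integer frequencies that sum to a power of two. The sketch below is one simple scheme, assumed here for illustration rather than taken from the source: reserve one slot per symbol (so nothing becomes unencodable), split the rest proportionally, and hand leftover slots to the largest rounding remainders. `quantize_probs` and the random logits are hypothetical.

```python
import numpy as np


def quantize_probs(probs, precision_bits=12):
    """Turn a probability vector into integer rANS frequencies summing to 2**precision_bits.

    Every symbol gets at least one slot so it stays encodable, even if its
    probability is (near) zero. Returns (freqs, cum) as int64 arrays.
    """
    probs = np.asarray(probs, dtype=np.float64)
    probs = probs / probs.sum()
    n = probs.size
    M = 1 << precision_bits
    if n > M:
        raise ValueError("alphabet is larger than the frequency budget M")
    spare = M - n  # slots left after reserving one per symbol
    scaled = probs * spare
    freqs = 1 + np.floor(scaled).astype(np.int64)
    # Flooring loses fewer than n slots; give them to the largest remainders.
    leftover = M - int(freqs.sum())
    order = np.argsort(-(scaled - np.floor(scaled)), kind="stable")
    freqs[order[:leftover]] += 1
    # Cumulative frequency c_s = sum of frequencies of all earlier symbols.
    cum = np.concatenate(([0], np.cumsum(freqs)[:-1])).astype(np.int64)
    return freqs, cum


# Example: a softmax over hypothetical model logits for a 1024-token vocabulary.
rng = np.random.default_rng(0)
logits = rng.normal(size=1024) * 3.0
probs = np.exp(logits - logits.max())
probs /= probs.sum()
freqs, cum = quantize_probs(probs, precision_bits=16)
print(freqs.sum(), freqs.min())
```

The reserved slot costs a little efficiency on very likely symbols but guarantees that any token the model under-predicts can still be coded.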

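Encoding Process (rANS):
- Quantizes the probabilities to integer frequencies f_s that sum to M, so each symbol s owns f_s of every M consecutive states
- Computes the new state from the old state x, the symbol's frequency f_s, and its cumulative frequency c_s
- Shifts low bits out of the state whenever it would grow too large (renormalization); because ANS behaves like a stack (LIFO), symbols are encoded in reverse order so the decoder emits them in forward order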
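For reference, the standard rANS encoding step can be written with quantized frequencies f_s, cumulative frequencies c_s = sum over t < s of f_t, and M = sum over s of f_s. The renormalization bound assumes the state is kept in [L, 2^b L), that b bits are emitted at a time, and that M divides L, which is the convention used by common byte-wise implementations:

```latex
x' = C(s, x) = \left\lfloor \frac{x}{f_s} \right\rfloor M + c_s + (x \bmod f_s),
\qquad
\log_2 x' \approx \log_2 x + \log_2 \frac{M}{f_s}

\text{before encoding } s:\ \text{emit the low } b \text{ bits of } x \text{ while } x \ge \frac{L}{M}\, 2^{b} f_s
```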

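Decoding Process (rANS):
- Reads the final encoder state from the stream, then reads further bits whenever the state drops below its lower bound
- Takes the current state modulo M to get a slot, and finds the symbol whose cumulative-frequency range [c_s, c_s + f_s) contains that slot
- Recovers the previous state from the slot and the symbol's frequency, which undoes one encoding step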
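The steps above can be put together as a byte-oriented rANS coder. This is a sketch modeled on the widely used 32-bit-state / 8-bit-output layout (lower bound L = 2^23), assuming a hypothetical static four-symbol model with 12-bit frequencies; the function and constant names are illustrative, not from a specific library.

```python
import random

L = 1 << 23          # lower bound of the normalized state interval [L, 256 * L)
SCALE_BITS = 12      # frequencies sum to M = 2**12
M = 1 << SCALE_BITS

# Hypothetical static model: frequencies for symbols 0..3 (they sum to M).
FREQS = [2048, 1024, 768, 256]
CUM = [0, 2048, 3072, 3840]
SLOT_TO_SYM = [s for s, f in enumerate(FREQS) for _ in range(f)]


def rans_encode(symbols):
    out = []  # bytes in emission order; reversed at the end
    x = L
    for s in reversed(symbols):  # LIFO: encode backwards
        f, c = FREQS[s], CUM[s]
        x_max = ((L >> SCALE_BITS) << 8) * f
        while x >= x_max:  # renormalize: shift out low bytes
            out.append(x & 0xFF)
            x >>= 8
        x = ((x // f) << SCALE_BITS) + (x % f) + c
    for _ in range(4):  # flush the 32-bit final state, low byte first
        out.append(x & 0xFF)
        x >>= 8
    return bytes(reversed(out))  # decoder reads front to back


def rans_decode(data, n):
    x = int.from_bytes(data[:4], "big")
    pos = 4
    out = []
    for _ in range(n):
        slot = x & (M - 1)  # x mod M
        s = SLOT_TO_SYM[slot]
        x = FREQS[s] * (x >> SCALE_BITS) + slot - CUM[s]
        while x < L:  # renormalize: pull bytes back in
            x = (x << 8) | data[pos]
            pos += 1
        out.append(s)
    assert x == L and pos == len(data), "stream not fully consumed"
    return out


rng = random.Random(0)
msg = rng.choices(range(4), weights=FREQS, k=10_000)
data = rans_encode(msg)
print(len(data), "bytes for", len(msg), "symbols")
```

Because the decoder undoes the encoder's last step first, reversing the symbol order during encoding is what lets decoding run forward. Swapping the fixed `FREQS`/`CUM` for per-step tables gives the adaptive variant.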

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset


Load Dataset

In [2]:
import os

num_proc = os.cpu_count()  # use all available CPU cores
ds = load_dataset('commavq/commavq.py', data_dir='./commavq', num_proc=num_proc, trust_remote_code=True)

In [3]:
# ds['i'] is data shard i (shard names are strings, e.g. ds['0'])
# ds['i'][x] is the x-th video in shard 'i' (1 minute long, 1200 frames)
# ds['i'][x]['path'] is the path to that video's NumPy token file


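tokens = []
examples = 1  # number of videos to load; use len(ds['0']) for the whole shard
for i in range(examples):
    t = np.load(ds['0'][i]['path'])  # token array for one video
    t = torch.from_numpy(t)
    tokens.append(t)

tokens = torch.cat(tokens, dim=0)  # concatenate videos along the frame axis
tokens.shape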


torch.Size([1200, 8, 16])

Create Static Probability Model 

In [5]:
# Calculate global token frequencies
unique_tokens, counts = torch.unique(tokens, return_counts=True) # unique_tokens is sorted

# Calculate probabilities
total_tokens = tokens.numel()
probabilities = counts.float() / total_tokens

# Create probability matrix [token_id, probability]
# Original token IDs are integers. Convert unique_tokens to float for stacking.
probability_matrix = torch.stack((unique_tokens.float(), probabilities), dim=1)

print(f"Total unique tokens: {probability_matrix.shape[0]}")
print(f"Total tokens in the dataset: {total_tokens}")
print("Probability model (token_id: probability):")
# For brevity, print only the first and last 10 entries if the model is large.
# Build (token_id, probability) pairs from probability_matrix, which is sorted by token_id.
items = []
for i in range(probability_matrix.shape[0]):
    token_id = int(probability_matrix[i, 0].item()) # Convert token_id back to int for display
    prob = probability_matrix[i, 1].item()
    items.append((token_id, prob))

if len(items) > 20:
    for token_id, prob in items[:10]:
        print(f"  {token_id}: {prob:.6f}")
    print("  ...")
    for token_id, prob in items[-10:]:
        print(f"  {token_id}: {prob:.6f}")
else:
    for token_id, prob in items:
        print(f"  {token_id}: {prob:.6f}")

Total unique tokens: 1019
Total tokens in the dataset: 153600
Probability model (token_id: probability):
  0: 0.000742
  1: 0.000540
  2: 0.002103
  3: 0.000378
  4: 0.003542
  5: 0.000312
  6: 0.000560
  7: 0.000156
  8: 0.000384
  9: 0.000553
  ...
  1014: 0.001803
  1015: 0.000957
  1016: 0.001126
  1017: 0.001510
  1018: 0.001133
  1019: 0.000807
  1020: 0.000176
  1021: 0.001211
  1022: 0.000794
  1023: 0.000885
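As a sanity check on this static model, its order-0 entropy approximates the size an rANS coder driven by these frequencies would reach (ignoring frequency quantization and header overhead), which can be compared against packing each token in 10 bits, since the token ids above run from 0 to 1023. Only 1019 of those 1024 ids occur in this one video, so a static model reused on other videos would need nonzero frequencies for the unseen ids. This is a minimal sketch: the helper names are hypothetical, and the demo uses random tokens rather than the commaVQ data.

```python
import torch


def static_entropy_bits(tokens: torch.Tensor) -> float:
    """Order-0 empirical entropy of an integer tensor, in bits per token."""
    _, counts = torch.unique(tokens, return_counts=True)
    p = counts.double() / tokens.numel()
    return float(-(p * torch.log2(p)).sum())


def static_size_estimate(tokens: torch.Tensor, raw_bits: int = 10) -> dict:
    """Ideal coded size under the static model vs. packing each token in raw_bits."""
    h = static_entropy_bits(tokens)
    n = tokens.numel()
    return {
        "bits_per_token": h,
        "ideal_bytes": h * n / 8,
        "raw_bytes": raw_bits * n / 8,
        "ratio": raw_bits / h if h > 0 else float("inf"),
    }


# Synthetic stand-in for the `tokens` tensor: 1200 frames of 8x16 ids in [0, 1024).
demo = torch.randint(0, 1024, (1200, 8, 16), generator=torch.Generator().manual_seed(0))
print(static_size_estimate(demo))
```

Applied to the `tokens` tensor loaded above, `static_size_estimate(tokens)` gives the static-model baseline that an adaptive (e.g. Transformer-driven) rANS coder would need to beat.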
