In [93]:
"""
Pure Python implementation of rANS, by Jamie Townsend, to accompany the
tutorial paper on rANS at https://arxiv.org/abs/2001.09186. The same variable
names are used in this file as are used in the tutorial.

We use the names `push` and `pop` for encoding and decoding respectively. The
compressed message is a pair `m = (s, t)`, where `s` is an int in the range `[2
** (s_prec - t_prec), 2 ** s_prec)` and `t` is an immutable stack, implemented
using a cons list, containing ints in the range `[0, 2 ** t_prec)` (`prec` is
short for 'precision'). The precisions must satisfy

  t_prec < s_prec.

For convenient compatibility with C/Numpy types we use the settings
s_prec = 64 and t_prec = 32.

Both the `push` function and the `pop` function assume access to a probability
distribution over symbols. We use the name `x` for a symbol. To describe the
probability distribution we model the real interval [0, 1] with the range of
integers [0, 1, 2, ..., 2 ** p_prec]. Each symbol is represented by a
sub-interval within that range. This can be visualized for a probability
distribution over the set of symbols {'a', 'b', 'c', 'd'}:

    0                                                             1
    |                           |---- P(x)----|                   |
    |                                                             |
    |   'a'           'b'           x == 'c'           'd'        |
    |----------|----------------|-------------|-------------------|
    |                              ^                              |
    |------------ c ------------|--|--- p ----|                   |
    0                            s_bar                        2 ** p_prec

Each sub-interval can be represented by a pair of non-negative integers: `c`
and `p`. As shown in the above diagram, the number `p` represents the width
of the interval corresponding to `x`, so that

  P(x) == p / 2 ** p_prec

where P is the probability mass function of our distribution.

The number `c` represents the beginning of the interval corresponding to `x`,
and is analagous to the cumulative distribution function evaluated on `x`.

The model over symbols, which must be provided by the user, is specified by the
triple (f, g, p_prec). The function g does the mapping

  g: x |-> c, p

and the function f does the mapping

  f: s_bar |-> x, (c, p)

where s_bar is in {0, 1, ..., 2 ** p_prec - 1}. The values returned by f should
be the x, p and c corresponding to the sub-interval containing s_bar, as shown
in the diagram above.
"""
s_prec = 8
t_prec = 4
t_mask = (1 << t_prec) - 1
s_min  = 1 << s_prec - t_prec
s_max  = 1 << s_prec

#        s    , t
m_init = s_min, ()  # Shortest possible message

def rans(model):
    f, g, p_prec = model
    def push(m, x):
        print('push')
        print('m, x: ',m,x )
        s, t = m
        print('s, t: ', s,t)
        c, p = g(x)
        print('c, p: ', c,p)
        # Invert renorm
        while s >= p << s_prec - p_prec:
            s, t = s >> t_prec, (t, s & t_mask)
        print('after invert renorm: ', s,t)
        # Invert d
        print('inverting ')
        print('s // p: ',s // p )
        print('(s // p << p_prec): ', (s // p << p_prec))
        print('s % p: ', s % p)
        print('c: ',c)
        s = (s // p << p_prec) + s % p + c
        print('after invert d: ', s)
        assert s_min <= s < s_max
        
        return s, t

    def pop(m):
        print('pop')
        s, t = m
        print('s,m: ', s,m)
        # d(s)
        s_bar = s & ((1 << p_prec) - 1) # d1 = s mod 2^r
        print('s_bar or d1(s): ', s_bar)
        x, (c, p) = f(s_bar) # d2(s_bar)
        print('d2(s_bar): ', x)
        print('c,p: ', c,p)
        print('computing s"')
        print('s >> p_prec,s mod 2^r,c: ',s >> p_prec,s_bar,c)
        print('p * (s >> p_prec): ', p * (s >> p_prec))
        print('s_bar or s mod 2^r): ', s_bar)
        print('c: ',c)
        print('s": ',s)
        # Renormalize
        while s < s_min:
            t, t_top = t
            s = (s << t_prec) + t_top # 2^r_t * s + t_top
        assert s_min <= s < s_max
        print('renormalized (s",t): ', s,t)
        return (s, t), x
    return push, pop

def flatten_stack(t):
    flat = []
    while t:
        t, t_top = t
        flat.append(t_top)
    return flat

def unflatten_stack(flat):
    t = ()
    for t_top in reversed(flat):
        t = t, t_top
    return t

In [94]:
import math

log = math.log2

# We encode some data using the example model in the paper and verify the
# inequality in equation (20).

# First setup the model
p_prec = 3

# Cumulative probabilities
# c_j = sum_{k=1}^{j-1} p_k if j= 2,...I, else 0 (i.e. for j=1)
cs = {'a': 0,
      'b': 1,
      'c': 3,
      'd': 6}

# Probability weights, must sum to 2 ** p_prec
ps = {'a': 1,
      'b': 2,
      'c': 3,
      'd': 2}

# Backwards mapping

s_bar_to_x = {0: 'a',
              1: 'b', 2: 'b',
              3: 'c', 4: 'c', 5: 'c',
              6: 'd', 7: 'd'}

def f(s_bar):
    x = s_bar_to_x[s_bar]  # d_2 = a_i where i = max{j: c_j < s_bar}
    c, p = cs[x], ps[x]
    return x, (c, p)

def g(x):
    return cs[x], ps[x]

model = f, g, p_prec

push, pop = rans(model)

# Some data to compress
# xs = ['a', 'b', 'b', 'c', 'b', 'c', 'd', 'c', 'c']
xs = ['c', 'b', 'b',  'a', 'd','c']
# Compute h(xs):
h = sum(map(lambda x: log(2 ** p_prec / ps[x]), xs))
print('Information content of sequence: h(xs) = {:.2f} bits.'.format(h))
print()

# Initialize the message
m = m_init

# Encode the data
for x in xs:
    m = push(m, x)

Information content of sequence: h(xs) = 11.83 bits.

push
m, x:  (16, ()) c
s, t:  16 ()
c, p:  3 3
after invert renorm:  16 ()
inverting 
s // p:  5
(s // p << p_prec):  40
s % p:  1
c:  3
after invert d:  44
push
m, x:  (44, ()) b
s, t:  44 ()
c, p:  1 2
after invert renorm:  44 ()
inverting 
s // p:  22
(s // p << p_prec):  176
s % p:  0
c:  1
after invert d:  177
push
m, x:  (177, ()) b
s, t:  177 ()
c, p:  1 2
after invert renorm:  11 ((), 1)
inverting 
s // p:  5
(s // p << p_prec):  40
s % p:  1
c:  1
after invert d:  42
push
m, x:  (42, ((), 1)) a
s, t:  42 ((), 1)
c, p:  0 1
after invert renorm:  2 (((), 1), 10)
inverting 
s // p:  2
(s // p << p_prec):  16
s % p:  0
c:  0
after invert d:  16
push
m, x:  (16, (((), 1), 10)) d
s, t:  16 (((), 1), 10)
c, p:  6 2
after invert renorm:  16 (((), 1), 10)
inverting 
s // p:  8
(s // p << p_prec):  64
s % p:  0
c:  6
after invert d:  70
push
m, x:  (70, (((), 1), 10)) c
s, t:  70 (((), 1), 10)
c, p:  3 3
after invert renorm:  70 ((()

In [95]:
# Verify the inequality in eq (20)
eps = log(1 / (1 - 2 ** -(s_prec - p_prec - t_prec)))
print('eps = {:.2e}'.format(eps))
print()

s, t = m
lhs = log(s) + t_prec * len(flatten_stack(t)) - log(s_min)
rhs = h + len(xs) * eps
print('Eq (20) inequality, rhs - lhs == {:.2e}'.format(rhs - lhs))
print()

# Decode the message, check that the decoded data matches original
xs_decoded = []
for _ in range(len(xs)):
    m, x = pop(m)
    xs_decoded.append(x)

xs_decoded = reversed(xs_decoded)

for x_orig, x_new in zip(xs, xs_decoded):
    assert x_orig == x_new

# Check that the message has been returned to its original state
assert m == m_init
print('Decode successful!')

eps = 1.00e+00

Eq (20) inequality, rhs - lhs == 6.28e+00

pop
s,m:  188 (188, (((), 1), 10))
s_bar or d1(s):  4
d2(s_bar):  c
c,p:  3 3
computing s"
s >> p_prec,s mod 2^r,c:  23 4 3
p * (s >> p_prec):  69
s_bar or s mod 2^r):  4
c:  3
s":  188
renormalized (s",t):  188 (((), 1), 10)
pop
s,m:  188 (188, (((), 1), 10))
s_bar or d1(s):  4
d2(s_bar):  c
c,p:  3 3
computing s"
s >> p_prec,s mod 2^r,c:  23 4 3
p * (s >> p_prec):  69
s_bar or s mod 2^r):  4
c:  3
s":  188
renormalized (s",t):  188 (((), 1), 10)
pop
s,m:  188 (188, (((), 1), 10))
s_bar or d1(s):  4
d2(s_bar):  c
c,p:  3 3
computing s"
s >> p_prec,s mod 2^r,c:  23 4 3
p * (s >> p_prec):  69
s_bar or s mod 2^r):  4
c:  3
s":  188
renormalized (s",t):  188 (((), 1), 10)
pop
s,m:  188 (188, (((), 1), 10))
s_bar or d1(s):  4
d2(s_bar):  c
c,p:  3 3
computing s"
s >> p_prec,s mod 2^r,c:  23 4 3
p * (s >> p_prec):  69
s_bar or s mod 2^r):  4
c:  3
s":  188
renormalized (s",t):  188 (((), 1), 10)
pop
s,m:  188 (188, (((), 1), 10))
s_

AssertionError: 

In [90]:
import numpy as np
np.arange(48)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])

In [24]:
s_min

4294967296

In [92]:
11<<t_prec

176