In [1]:
"""
Closely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h by
Fabian Giesen.

We use the pythonic names `append` and `pop` for encoding and decoding
respectively. The compressed state is a pair `msg = (head, tail)`, where `head`
is an int in the range `[0, 2 ** head_precision)` and `tail` is an immutable
stack, implemented using a cons list, containing ints in the range
`[0, 2 ** tail_precision)`. The precisions must satisfy

  tail_precision < head_precision <= 2 * tail_precision.

For convenient compatibility with Numpy dtypes we use the settings
head_precision = 64 and tail_precision = 32.

Both the `append` method and the `pop` method assume access to a probability
distribution over symbols. We use the name `symb` for a symbol. To describe the
probability distribution we model the real interval [0, 1] with the range of
integers {0, 1, 2, ..., 2 ** precision}. Each symbol is represented by a
sub-interval within that range. This can be visualized for a probability
distribution over the set of symbols {a, b, c, d}:

    0                                                             1
    |          |----- P(symb) ------|                             |
    |                                                             |
    |    a           symb == b           c              d         |
    |----------|--------------------|---------|-------------------|
    |                                                             |
    |          |------ prob --------|                             |
    0        start                                            2 ** precision

Each sub-interval can be represented by a pair of non-negative integers:
`start` and `prob`. As shown in the above diagram, the number `prob` represents
the width of the interval, corresponding to `symb`, so that

  P(symb) = prob / 2 ** precision

where P is the probability mass function of our distribution.

The number `start` represents the beginning of the interval corresponding to
`symb`, which is analagous to the cumulative distribution function evaluated on
`symb`.
"""

'\nClosely based on https://github.com/rygorous/ryg_rans/blob/master/rans64.h by\nFabian Giesen.\n\nWe use the pythonic names `append` and `pop` for encoding and decoding\nrespectively. The compressed state is a pair `msg = (head, tail)`, where `head`\nis an int in the range `[0, 2 ** head_precision)` and `tail` is an immutable\nstack, implemented using a cons list, containing ints in the range\n`[0, 2 ** tail_precision)`. The precisions must satisfy\n\n  tail_precision < head_precision <= 2 * tail_precision.\n\nFor convenient compatibility with Numpy dtypes we use the settings\nhead_precision = 64 and tail_precision = 32.\n\nBoth the `append` method and the `pop` method assume access to a probability\ndistribution over symbols. We use the name `symb` for a symbol. To describe the\nprobability distribution we model the real interval [0, 1] with the range of\nintegers {0, 1, 2, ..., 2 ** precision}. Each symbol is represented by a\nsub-interval within that range. This can be visualized 

In [2]:
import numpy as np
from functools import reduce
import util

In [31]:
def append(msg, start, prob, precision):
    """
    Encodes a symbol with range `[start, start + prob)`.  All `prob`s are
    assumed to sum to `2 ** precision`. Compressed bits get written to `msg`.
    """
    # Prevent Numpy scalars leaking in
    start, prob, precision = map(int, [start, prob, precision])
    head, tail = msg
    print('width of prob distribution range in integer units: ',1<<precision)
    print('input head:', head)
    print('start, prob, precision: ',start, prob, precision )
    if head >= prob << head_precision - precision:
        print('prob << head_precision - precision = ',prob << head_precision - precision)
        # Need to push data down into tail
        print('new head', head >> tail_precision)
        print('new tail', (head & tail_mask, tail))
        head, tail = head >> tail_precision, (head & tail_mask, tail)
    print('appended head components, \nhead // prob << precision = {}, head % prob = {}, start = {}\nhead = {}   '.format(
          (head // prob << precision),
          head % prob,
          start,
          (head // prob << precision) + head % prob + start)
         )
    return (head // prob << precision) + head % prob + start, tail

def pop(msg, statfun, precision):
    """
    Pops a symbol from msg. The signiature of statfun should be
        statfun: cf |-> symb, (start, prob)
    where `cf` is in the interval `[start, start + prob)` and `symb` is the
    symbol corresponding to that interval.
    """
    # Prevent Numpy scalars leaking in
    precision = int(precision)
    head, tail = msg
    cf = head & ((1 << precision) - 1)
    symb, (start, prob) = statfun(cf)
    # Prevent Numpy scalars leaking in
    start, prob = int(start), int(prob)
    head = prob * (head >> precision) + cf - start
    if head < head_min:
        # Need to pull data up from tail
        head_new, tail = tail
        head = (head << tail_precision) + head_new
    return (head, tail), symb

def append_symbol(statfun, precision):
    def append_(msg, symbol):
        start, prob = statfun(symbol)
        return append(msg, start, prob, precision)
    return append_

def pop_symbol(statfun, precision):
    def pop_(msg):
        return pop(msg, statfun, precision)
    return pop_

def flatten(msg):
    """Flatten a rANS message into a 1d numpy array."""
    out, msg = [msg[0] >> 32, msg[0]], msg[1]
    while msg:
        x_head, msg = msg
        out.append(x_head)
    return np.asarray(out, dtype=np.uint32)

def unflatten(arr):
    """Unflatten a 1d numpy array into a rANS message."""
    #print(int(arr[0])<<32)
    #print(int(arr[1]))
    return (int(arr[0]) << 32 | int(arr[1]),
            reduce(lambda tl, hd: (int(hd), tl), reversed(arr[2:]), ()))

In [29]:
head_precision = 8
tail_precision = 4
tail_mask = (1 << tail_precision) - 1
head_min  = 1 << head_precision - tail_precision
print(tail_mask)
print(head_min)

15
16


In [30]:
#          head    , tail
msg_init = head_min, ()
msg_init

(16, ())

In [6]:
# other_bits = [1,2,2]
# state = unflatten(other_bits)
# print(state)    

In [28]:
uniform_enc_statfun = lambda s: (s, 1)
uniform_dec_statfun = lambda cf: (cf, (cf, 1))

In [27]:
state=append_symbol(uniform_enc_statfun,precision=4)(msg_init,4)
state

width of prob distribution in discrete units:  16
input head: 4
start, prob, precision:  4 1 4
prob << head_precision - precision =  1
new head 1
new tail (0, ())
appended head components, 
head // prob << precision = 16, head % prob = 0, start = 4
head = 20   


(20, (0, ()))

In [9]:
state=append_symbol(uniform_enc_statfun,precision=4)(state,4)
state

width of prob distribution in discrete units:  16
input head: 20
start, prob, precision:  4 1 4
prob << head_precision - precision =  16
new head 1
new tail (4, (0, ()))
appended head components, 
head // prob << precision = 16, head % prob = 0, start = 4
head = 20   


(20, (4, (0, ())))

In [10]:
state=append_symbol(uniform_enc_statfun,precision=head_precision)(state,1)
state

width of prob distribution in discrete units:  256
input head: 20
start, prob, precision:  1 1 8
prob << head_precision - precision =  1
new head 1
new tail (4, (4, (0, ())))
appended head components, 
head // prob << precision = 256, head % prob = 0, start = 1
head = 257   


(257, (4, (4, (0, ()))))

In [11]:
state=append_symbol(uniform_enc_statfun,precision=head_precision)(state,1)
state

width of prob distribution in discrete units:  256
input head: 257
start, prob, precision:  1 1 8
prob << head_precision - precision =  1
new head 16
new tail (1, (4, (4, (0, ()))))
appended head components, 
head // prob << precision = 4096, head % prob = 0, start = 1
head = 4097   


(4097, (1, (4, (4, (0, ())))))

In [12]:
state=append_symbol(uniform_enc_statfun,precision=head_precision)(state,3)
state

width of prob distribution in discrete units:  256
input head: 4097
start, prob, precision:  3 1 8
prob << head_precision - precision =  1
new head 256
new tail (1, (1, (4, (4, (0, ())))))
appended head components, 
head // prob << precision = 65536, head % prob = 0, start = 3
head = 65539   


(65539, (1, (1, (4, (4, (0, ()))))))

In [13]:
state=append_symbol(uniform_enc_statfun,precision=head_precision)(state,3)
state

width of prob distribution in discrete units:  256
input head: 65539
start, prob, precision:  3 1 8
prob << head_precision - precision =  1
new head 4096
new tail (3, (1, (1, (4, (4, (0, ()))))))
appended head components, 
head // prob << precision = 1048576, head % prob = 0, start = 3
head = 1048579   


(1048579, (3, (1, (1, (4, (4, (0, ())))))))

In [14]:
flatten(state)

array([      0, 1048579,       3,       1,       1,       4,       4,
             0], dtype=uint32)

In [15]:
state,symbol = pop_symbol(uniform_dec_statfun,precision=head_precision)(state)
state,symbol

((4096, (3, (1, (1, (4, (4, (0, ()))))))), 3)

In [16]:
state,symbol = pop_symbol(uniform_dec_statfun,precision=head_precision)(state)
state,symbol

((16, (3, (1, (1, (4, (4, (0, ()))))))), 0)

In [17]:
state,symbol = pop_symbol(uniform_dec_statfun,precision=head_precision)(state)
state,symbol

((3, (1, (1, (4, (4, (0, ())))))), 16)

In [18]:
state,symbol = pop_symbol(uniform_dec_statfun,precision=head_precision)(state)
state,symbol

((1, (1, (4, (4, (0, ()))))), 3)

In [23]:
np.shape([[2,3],[3,6],[3,8]])

(3, 2)