# Py-tANS

In [1]:
#####
# Imports
#####
from collections import Counter
from math import floor, ceil
import random
import numpy as np
import matplotlib.pyplot as plt

# tANS table configuration: the state table has 2**tableLog slots/states.
tableLog = 5
tableSize = 2 ** tableLog

In [2]:
# Return the Index of the First Non-Zero Bit.
def first1Index(val):
    counter = 0
    while val > 1:
        counter += 1
        val = val >> 1
    return counter

In [3]:
# Define how often a symbol is seen, total should equal the 
# table size.
symbol_occurrences = {"0":10,"1":10, "2":12}


In [4]:
####
# Define the Initial Positions of States in StateList.
####
# Symbols in dictionary (insertion) order; a symbol's position in this list is
# the index used for it in `cumulative` and in the table-building cells.
symbol_list = list(symbol_occurrences)
# cumulative[k] = total occurrences of symbol_list[0..k-1], i.e. the first
# coding-table slot of symbol k. There are two extra entries:
# cumulative[len(symbol_list)] holds the grand total, and the final entry is
# set to tableSize + 1 (a sentinel value; not read by the cells below).
cumulative = [0] * (len(symbol_list) + 2)
for k, count in enumerate(symbol_occurrences.values(), start=1):
    cumulative[k] = cumulative[k - 1] + count
cumulative[-1] = tableSize + 1
# The occurrence counts must exactly fill the state table, otherwise the
# spreading and coding-table cells silently overwrite or leave slots empty.
assert cumulative[len(symbol_list)] == tableSize, (
    "symbol occurrences must sum to tableSize")

In [5]:
#####
# Spread Symbols to Create the States Table
#####
# Scatter each symbol's occurrences across the table by repeatedly stepping a
# fixed stride modulo tableSize. The stride must be odd (coprime with the
# power-of-two tableSize) so that the walk visits every slot exactly once.
highThresh = tableSize - 1   # highest slot index allowed to receive a symbol
stateTable = [0 for _ in range(tableSize)]
tableMask = tableSize - 1    # `& tableMask` == `% tableSize` for a power of two
step = ((tableSize >> 1) + (tableSize >> 3) + 3)
assert step & 1, "spread step must be odd to visit every slot"
pos = 0
for symbol, occurrences in symbol_occurrences.items():
    for _ in range(occurrences):
        stateTable[pos] = symbol
        pos = (pos + step) & tableMask
        # Skip slots above highThresh. `pos` itself must advance here, otherwise
        # the loop can never terminate. (With highThresh == tableSize - 1 and pos
        # masked into [0, tableSize) this loop does not run at all.)
        while pos > highThresh:
            pos = (pos + step) & tableMask
# After placing exactly tableSize symbols the walk is back at slot 0.
assert(pos == 0)
print(stateTable)

['0', '0', '1', '2', '2', '0', '1', '1', '2', '2', '0', '1', '2', '2', '0', '0', '1', '2', '2', '0', '1', '1', '2', '0', '0', '1', '2', '2', '0', '1', '1', '2']


In [6]:
#####
# Build Coding Table from State Table
#####
# Walk the spread table in slot order. Each symbol has a write cursor that starts
# at its cumulative offset, so the encoder states (tableSize + slot) of one
# symbol end up stored contiguously in codingTable, in increasing slot order.
outputBits = [0] * tableSize
codingTable = [0] * tableSize
cumulative_cp = list(cumulative)
for slot, sym in enumerate(stateTable):
    k = symbol_list.index(sym)
    codingTable[cumulative_cp[k]] = tableSize + slot
    cumulative_cp[k] += 1
    # NOTE(review): tableSize + slot lies in [tableSize, 2*tableSize), so this is
    # always tableLog - tableLog == 0; outputBits is not read by later cells.
    outputBits[slot] = tableLog - first1Index(tableSize + slot)

In [7]:
#####
# Create the Symbol Transformation Table
#####
# For each symbol, precompute (as used by encodeSymbol):
#   deltaNbBits    - added to the encoder state; the sum `>> 16` is the number
#                    of low state bits flushed before encoding the symbol.
#   deltaFindState - offset added to (state >> nbBitsOut) to index codingTable.
# `total` is the running count of occurrences of the symbols already handled,
# i.e. the first codingTable slot belonging to the current symbol.
total = 0
symbolTT = {}
for symbol, occurrences in symbol_occurrences.items():
    entry = symbolTT[symbol] = {}
    if occurrences == 1:
        entry['deltaNbBits'] = (tableLog << 16) - (1 << tableLog)
        entry['deltaFindState'] = total - 1
        # A single-occurrence symbol still owns one codingTable slot, so the
        # running total must advance for the symbols that follow it.
        total += 1
    elif occurrences > 0:
        maxBitsOut = tableLog - first1Index(occurrences - 1)
        minStatePlus = occurrences << maxBitsOut
        entry['deltaNbBits'] = (maxBitsOut << 16) - minStatePlus
        entry['deltaFindState'] = total - occurrences
        total += occurrences
    # Symbols with zero occurrences keep an empty entry and cannot be encoded.
print(symbolTT)

{'0': {'deltaNbBits': 131032, 'deltaFindState': -10}, '1': {'deltaNbBits': 131032, 'deltaFindState': 0}, '2': {'deltaNbBits': 131024, 'deltaFindState': 8}}


In [8]:
# Render the low `nbBits` bits of a state as a bit string.
def outputNbBits(state, nbBits):
    """Return the low ``nbBits`` bits of ``state`` as a string of '0'/'1'.

    The string is most-significant-bit first and left-padded with zeros to
    exactly ``nbBits`` characters. ``nbBits == 0`` yields "". A negative
    ``nbBits`` raises ValueError (negative shift count).
    """
    low = state & ((1 << nbBits) - 1)
    if not nbBits > 0:
        return ""
    return "{:b}".format(low).rjust(nbBits, "0")

In [9]:
# Encode one symbol with tANS: flush some low state bits, then move to the next state.
def encodeSymbol(symbol, state, bitStream, symbolTT):
    """Encode ``symbol`` from encoder ``state``.

    ``symbolTT`` maps each symbol to a dict with 'deltaNbBits' and
    'deltaFindState'. The number of bits to flush is
    ``(state + deltaNbBits) >> 16``; those low bits of ``state`` are appended to
    ``bitStream``. The next state is then read from the global ``codingTable``
    at ``(state >> nbBitsOut) + deltaFindState``.

    Returns ``(new_state, bitStream)``.
    """
    entry = symbolTT[symbol]
    nbBitsOut = (state + entry['deltaNbBits']) >> 16
    bitStream += outputNbBits(state, nbBitsOut)
    next_index = (state >> nbBitsOut) + entry['deltaFindState']
    return codingTable[next_index], bitStream

In [10]:
#####
# Generate a Decoding Table
#####
# Scan stateTable in slot order. Each symbol keeps a counter `x` that starts at
# its occurrence count and goes up by one for every slot it owns. For a slot,
# `x` determines how many bits the decoder reads (nbBits) and the base state
# (newX) to which those bits are added.
decodeTable = [{} for _ in range(tableSize)]
nextt = list(symbol_occurrences.items())
for i, sym in enumerate(stateTable):
    k = symbol_list.index(sym)
    name, x = nextt[k]
    nextt[k] = (name, x + 1)
    nbBits = tableLog - first1Index(x)
    decodeTable[i] = {
        'symbol': sym,
        'nbBits': nbBits,
        'newX': (x << nbBits) - tableSize,
    }

In [11]:
# Pop bits off the end of the bit stream and turn them into an integer.
def bitsToState(bitStream, nbBits):
    """Split the last ``nbBits`` characters off ``bitStream`` and parse them.

    Returns ``(value, remaining)``:
    - ``value`` is the integer encoded by the trailing '0'/'1' characters.
    - ``remaining`` is everything before them.

    If ``nbBits`` exceeds the stream length, the whole stream is parsed and
    ``remaining`` is "". Requesting ``nbBits <= 0`` bits returns
    ``(0, bitStream)`` unchanged.

    Raises ValueError if there are no bits to parse (empty ``bitStream`` with
    ``nbBits > 0``) or the characters are not binary digits.
    """
    if nbBits <= 0:
        # bitStream[-0:] is the whole string, not an empty slice, so zero-bit
        # reads must be handled explicitly instead of consuming the stream.
        return 0, bitStream
    return int(bitStream[-nbBits:], 2), bitStream[:-nbBits]

In [12]:
# Decode one symbol: look up the current state, then rebuild the previous state from stream bits.
def decodeSymbol(state, bitStream, stateT):
    """Decode a single symbol from decoder ``state``.

    ``stateT`` is a decoding table indexed by state whose entries hold
    'symbol', 'nbBits' and 'newX'. The entry's ``nbBits`` bits are taken from
    the end of ``bitStream`` and added to ``newX`` to form the next state.

    Returns ``(symbol, new_state, remaining_bitStream)``.
    """
    row = stateT[state]
    symbol = row['symbol']
    rest, remaining = bitsToState(bitStream, row['nbBits'])
    return symbol, row['newX'] + rest, remaining

In [13]:
# Split an input string into a list of one-character symbols.
def split(string):
    """Return the elements of ``string`` (its characters, for a str) as a list."""
    return list(string)

In [22]:
#####
# Functions to Encode and Decode Streams of Data.
#####
def encodeData(inpu):
    """Encode the sequence of symbols ``inpu`` into a string of '0'/'1'.

    The first symbol is encoded once from state 0 only to obtain a starting
    state; the bits produced by that seeding step are discarded. Every symbol
    of ``inpu`` (including the first) is then encoded in order, and the final
    state minus tableSize is appended as tableLog bits, which decodeData reads
    first.

    NOTE(review): when seeding from state 0, ``deltaFindState`` can be negative
    (e.g. -10 for '0'), so encodeSymbol indexes codingTable with a negative
    index that Python wraps around. The seed is still one of codingTable's
    entries.

    Raises IndexError for an empty ``inpu``.
    """
    # Seed: keep only the state produced by encoding the first symbol from state 0.
    state = encodeSymbol(inpu[0], 0, "", symbolTT)[0]
    bits = ""
    for char in inpu:
        state, bits = encodeSymbol(char, state, bits, symbolTT)
    # Trailer: final state (offset by tableSize) written as tableLog bits.
    return bits + outputNbBits(state - tableSize, tableLog)

def decodeData(bitStream):
    """Decode a bit string produced by encodeData back into a list of symbols.

    The last tableLog bits hold the final encoder state (minus tableSize).
    Symbols are then recovered in reverse encoding order, each one consuming
    its bits from the end of the remaining stream, until the stream is empty.
    Returns the symbols in the original (encoding) order. An empty
    ``bitStream`` decodes to an empty list.

    NOTE(review): decoding stops as soon as the stream is empty, so symbols at
    the start of a message that were encoded while flushing 0 bits would be
    lost. That can happen once a symbol occupies more than half of the table.
    With the counts used here every symbol flushes at least one bit. Confirm
    this before changing symbol_occurrences.
    """
    if not bitStream:
        return []
    output = []
    state, bitStream = bitsToState(bitStream, tableLog)
    while len(bitStream) > 0:
        symbol, state, bitStream = decodeSymbol(state, bitStream, decodeTable)
        output.append(symbol)
    # Symbols come out last-encoded-first. Appending and reversing once avoids
    # the quadratic cost of prepending to a list on every iteration.
    output.reverse()
    return output

In [34]:
# Test Encoding
# Encode a short sample message built from the symbols in symbol_occurrences;
# `inpu` is reused by the round-trip check further down.
inpu = "1102010120"
bitStream = encodeData(inpu)

In [35]:
# Test Decoding
# Decode the bit string produced by the encoding cell back into a list of symbols.
output = decodeData(bitStream)

In [36]:
# Assert that the input and the decoded output are the same.
# split(inpu) turns the input string into a list of one-character symbols so it
# can be compared element-wise with the decoded list.
# NOTE(review): the saved output below lacks the " = input" / " = output"
# labels and the blank line these prints produce, so it looks stale; re-run the
# notebook to refresh it.
print(inpu, " = input")
print()
print(output, " = output")
assert(split(inpu) == output)

1102010120
['1', '1', '0', '2', '0', '1', '0', '1', '2', '0']
